# !/usr/bin/env python
# -*- coding:utf-8 -*-

import math

import torch
from torch import nn
import numpy as np
from torch.nn import functional as F
import networkx as nx

from tools.utils import _h_A


class NonlinearTransforms(nn.Module):
    """Two-layer bottleneck MLP applied to the last dimension of its input.

    The mapping is ``FC2(activation(FC1(x)))`` where ``FC1`` projects from
    ``in_dim`` to ``out_dim // 4`` features and ``FC2`` projects back to
    ``in_dim``, so the output has the same trailing size as the input.

    Args:
        in_dim: Size of the last dimension of the input (and of the output).
        out_dim: Nominal hidden size; the actual hidden width is
            ``out_dim // 4``.
        activation: Callable applied between the two linear layers.
    """

    def __init__(self, in_dim, out_dim, activation):
        super().__init__()

        bottleneck = out_dim // 4
        # Layers are created in this order so seeded initialisation is stable.
        self.FC1 = nn.Linear(in_dim, bottleneck)
        self.FC2 = nn.Linear(bottleneck, in_dim)
        self.activation = activation

    def forward(self, eps):
        """Transform ``eps`` and return a tensor with the same trailing size.

        @param eps: tensor whose last dimension equals ``in_dim``
        @return: ``FC2(activation(FC1(eps)))``
        """
        hidden = self.FC1(eps)
        hidden = self.activation(hidden)
        return self.FC2(hidden)

class SCM(nn.Module):
    """Structural causal model layer with a learnable weighted adjacency.

    Exogenous noise ``eps`` is propagated through the linear SCM
    ``h = (I - A^T)^{-1} eps`` where ``A = sinh(alpha * Weight_DAG)``, and
    each of the ``hidden_num`` resulting rows is then passed through its own
    element-wise transform (``transforms0`` ... ``transforms{hidden_num-1}``).

    Args:
        in_dim: Total latent size; each node gets ``in_dim // hidden_num``
            features.
        hidden_dim: Passed as ``out_dim`` to every per-node transform.
        hidden_num: Number of nodes in the graph (side of ``Weight_DAG``).
        scm_type: ``'nonlinear'`` uses ``NonlinearTransforms`` per node,
            ``'linear'`` uses ``nn.Identity``.
        nonlinear_activation: Activation handed to each per-node transform.

    Raises:
        NotImplementedError: If ``scm_type`` is not one of the values above.
    """

    def __init__(self, in_dim, hidden_dim, hidden_num, scm_type='nonlinear',
                 nonlinear_activation=nn.ReLU()):
        super().__init__()
        self.alpha = 3
        self.hidden_dim = hidden_dim
        self.hidden_num = hidden_num  # num_label
        self.nonlinear_activation = nonlinear_activation
        # create_Weight_DAG already returns a fresh tensor; clone/detach it
        # directly instead of re-wrapping it with torch.tensor(), which
        # copies again and emits a UserWarning.
        self.Weight_DAG = nn.Parameter(
            self.create_Weight_DAG(self.hidden_num).clone().detach())

        self.register_zero_grad_hook()

        # Element-wise mappings, one per node. Both branches hold a *class*
        # so the construction call below is uniform; nn.Identity ignores its
        # constructor arguments.
        if scm_type == 'linear':
            transforms = nn.Identity
        elif scm_type == 'nonlinear':
            transforms = NonlinearTransforms
        else:
            raise NotImplementedError("Not supported prior network.")

        node_dim = in_dim // self.hidden_num
        for i in range(self.hidden_num):
            setattr(self, "transforms%d" % i,
                    transforms(node_dim, hidden_dim, self.nonlinear_activation))

    def register_zero_grad_hook(self):
        """Zero the gradient of the lower triangle and the main diagonal.

        Only the strictly upper-triangular entries of ``Weight_DAG`` receive
        gradient after this hook is registered.
        """
        def hook(grad):
            # Return a new tensor rather than mutating ``grad`` in place;
            # tensor hooks must not modify their argument.
            return torch.triu(grad, diagonal=1)

        self.Weight_DAG.register_hook(hook)

    def create_Weight_DAG(self, num):
        """Build the initial ``num x num`` weight matrix.

        The strictly upper triangle is drawn from U[0, 1); the lower triangle
        including the diagonal is drawn from U[-0.3, 0.3).

        @param num: number of nodes (side length of the matrix)
        @return: tensor of shape ``(num, num)``
        """
        upper_triangle = torch.triu(torch.rand(num, num), 1)
        # Must be (num, num) so the sum below is valid for any node count.
        lower_triangle = torch.tril(torch.empty(num, num).uniform_(-0.3, 0.3))
        return upper_triangle + lower_triangle

    def _amplified_dag(self):
        """Return ``sinh(alpha * Weight_DAG)``, the amplified adjacency."""
        return torch.sinh(self.alpha * self.Weight_DAG)

    def generate_z(self, eps):
        '''
        h = (I-A.T)^{-1}*eps

        z = f(h)

        Also caches the amplified adjacency on ``self.amplif_Weight_DAG`` for
        use by ``cal_loss``.

        @param eps: [batch, num, dim]
        @return: tensor obtained by applying the per-node transforms to the
                 rows of ``h`` and concatenating them along dim 1
        '''
        # Amplify the value of A to accelerate convergence.
        self.amplif_Weight_DAG = self._amplified_dag()

        identity = torch.eye(self.amplif_Weight_DAG.shape[0],
                             device=self.Weight_DAG.device)
        DAG_normalized = torch.inverse(identity - self.amplif_Weight_DAG.t())
        # (num, num) @ (batch, num, dim) broadcasts over the batch dimension.
        h = torch.matmul(DAG_normalized, eps)

        # Per-node transform: one slice of size 1 along dim 1 for each node.
        node_inputs = torch.split(h, 1, dim=1)
        zs = [getattr(self, "transforms%d" % i)(node_inputs[i])
              for i in range(self.hidden_num)]

        return torch.cat(zs, dim=1)

    def cal_loss(self):
        """Return the structure penalty on the amplified adjacency matrix.

        The loss is ``h_A + 0.5 * h_A**2 + 100 * trace(A * A)`` with ``*``
        element-wise, i.e. the last term is the sum of squared diagonal
        entries. ``h_A`` comes from ``tools.utils._h_A`` (presumably an
        acyclicity measure; confirm against that helper).

        Uses the matrix cached by the latest ``generate_z`` call; if none has
        been made yet, it is computed from the current ``Weight_DAG``.
        """
        one_adj_A = getattr(self, "amplif_Weight_DAG", None)
        if one_adj_A is None:
            one_adj_A = self._amplified_dag()

        # compute h(A)
        h_A = _h_A(one_adj_A, one_adj_A.shape[0])
        diag_penalty = 100. * torch.trace(one_adj_A * one_adj_A)
        return h_A + 0.5 * h_A * h_A + diag_penalty

    def forward(self, eps):
        '''
        Map flat noise to flat latents and return the structure loss.

        @param eps: [batch, hidden_num * dim]; reshaped internally to
                    [batch, hidden_num, dim]
        @return: (z, loss) where z is flattened from dim 1 onward and loss is
                 the value of ``cal_loss()``
        '''
        eps = eps.reshape(eps.shape[0], self.hidden_num,
                          eps.shape[1] // self.hidden_num)

        z = self.generate_z(eps)  # [batch, hidden_num, dim]
        z = z.flatten(start_dim=1)

        loss = self.cal_loss()
        return z, loss